import os
from pdb import set_trace

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import sem, t

# The experiment data is loaded from a CSV file inside the `if __name__ == '__main__':` block at the bottom of this file.
# df= pd.read_csv(f'all_runs_data_v2.csv')

# Default typography for the figures. Each plotting function below also takes
# its own font_size / tick_size keyword arguments, defaulting to these values.
font_size, tick_size = 25, 18

def calculate_correlations(grouped_data, x_metric, y_metric):
    """Return the Pearson correlation between two columns for every group.

    Parameters:
    - grouped_data: iterable of (name, DataFrame) pairs, e.g. a pandas GroupBy.
    - x_metric, y_metric: column names present in every group.

    Returns:
    - dict mapping each group name to the Pearson coefficient of the two columns.
    """
    # DataFrame.corr() yields a 2x2 matrix; entry [0, 1] is the off-diagonal
    # x/y coefficient.
    return {
        group_name: frame[[x_metric, y_metric]].corr().iloc[0, 1]
        for group_name, frame in grouped_data
    }

# Function to plot scatter plot for data after a specific step
def plot_scatter_with_regression_and_metrics(df, step_threshold, use_expansion, title_suffix, 
        x_metric = 'total advertiser participating value gain', y_metric = 'total advertiser utility gain zero bid offset', 
        filter_outliers = True, filter_threshold = 0.995, 
        save_plot = False, font_size = 25, tick_size = 18):
    """
    Plots scatter plot with a linear regression line for data after a specified step,
    for experiments with or without input expansion, annotated with R^2 and the slope.

    Parameters:
    - df: DataFrame containing the experiments data.
    - step_threshold: Integer, only rows with 'samples used' strictly greater than this are kept.
    - use_expansion: Boolean, True for data with input expansion, False for data without.
    - title_suffix: String appended to the plot title (the title is only drawn when the
      plot is shown, not when it is saved).
    - x_metric: column plotted on the x-axis; must be a key of _SCATTER_X_LABELS.
    - y_metric: column plotted on the y-axis; must be a key of _SCATTER_Y_LABELS.
    - filter_outliers: if True, drop rows where either metric is not strictly between its
      (1 - filter_threshold) and filter_threshold quantiles.
    - filter_threshold: upper quantile used for outlier filtering.
    - save_plot: True to save a PDF under ./plots/ (without title), False to show the plot.
    - font_size, tick_size: label and tick font sizes.

    Raises:
    - KeyError: if x_metric or y_metric has no known axis label.
    - ValueError: if fewer than two points remain to fit the regression line.
    """
    # Resolve the axis labels first, so an unsupported metric fails before any work is done.
    remove_gain_from_legends = True  # drop the word 'Gain' from the axis labels
    ylabel = _scatter_axis_label(y_metric, _SCATTER_Y_LABELS, 'y_metric', remove_gain_from_legends)
    xlabel = _scatter_axis_label(x_metric, _SCATTER_X_LABELS, 'x_metric', remove_gain_from_legends)

    # Keep rows after the step threshold for the requested input-expansion setting
    df_filtered = df[(df['samples used'] > step_threshold) & (df['use_input_expansion'] == use_expansion)]

    if filter_outliers:
        # Both pairs of bounds are computed on the same pre-filter data, then applied together.
        x_col, y_col = df_filtered[x_metric], df_filtered[y_metric]
        x_min, x_max = x_col.quantile(1 - filter_threshold), x_col.quantile(filter_threshold)
        y_min, y_max = y_col.quantile(1 - filter_threshold), y_col.quantile(filter_threshold)
        df_filtered = df_filtered[(x_col > x_min) & (x_col < x_max) & (y_col > y_min) & (y_col < y_max)]

    x = df_filtered[x_metric]
    y = df_filtered[y_metric]
    if len(x) < 2:
        raise ValueError(
            f'Need at least two points to fit a regression line, got {len(x)} '
            f'(step_threshold={step_threshold}, use_expansion={use_expansion})'
        )

    # Least-squares line y = m*x + b and its coefficient of determination
    m, b = np.polyfit(x, y, 1)
    y_pred = m * x + b
    ss_res = np.sum((y - y_pred) ** 2)
    ss_tot = np.sum((y - np.mean(y)) ** 2)
    r_squared = 1 - (ss_res / ss_tot)

    plt.figure(figsize=(12, 8))
    plt.scatter(x, y, alpha=0.15, color='orange')
    plt.plot(x, y_pred, color='blue')
    plt.ylabel(ylabel, fontsize=font_size)
    plt.xlabel(xlabel, fontsize=font_size)

    # Annotate R^2 and slope at the top-left corner of the data range
    plt.text(min(x), max(y), f'$R^2$ = {r_squared:.2f}\nSlope = {m:.2f}', fontsize=font_size,
             verticalalignment='top', bbox=dict(facecolor='white', alpha=0.5, edgecolor='black', boxstyle='round,pad=0.5'))

    plt.tick_params(axis='both', which='major', labelsize=tick_size)
    plt.grid(True)
    if not save_plot:
        # No title when saving: it goes in the paper caption instead.
        title = f'{x_metric} vs {y_metric} IE: {use_expansion} FO: {filter_outliers}'
        if title_suffix:
            title = f'{title} {title_suffix}'
        plt.title(title, fontsize=font_size - 10)
        plt.show()
    else:
        file_format = 'pdf'
        savename = f'./plots/scatter_plot_{x_metric}_{y_metric}_input_expansion_{use_expansion}_filter_outliers_{filter_outliers}.{file_format}'
        savename = savename.replace(' ', '_')
        os.makedirs(os.path.dirname(savename), exist_ok=True)
        plt.savefig(savename)
        plt.close()  # release the figure; callers save many plots in a loop


# Display labels for the y-axis metrics supported by plot_scatter_with_regression_and_metrics.
_SCATTER_Y_LABELS = {
    'total advertiser utility gain zero bid offset': 'Total Advertiser Utility Gain',
    'total advertiser utility gain no offset': 'Total Advertiser Utility Gain',
    'advertiser 0 utility gain zero bid offset': 'Advertiser 0 Utility Gain',
    'advertiser 1 utility gain zero bid offset': 'Advertiser 1 Utility Gain',
    'advertiser 0 utility gain no offset': 'Advertiser 0 Utility Gain',
    'advertiser 1 utility gain no offset': 'Advertiser 1 Utility Gain',
}

# Display labels for the x-axis metrics supported by plot_scatter_with_regression_and_metrics.
_SCATTER_X_LABELS = {
    'total advertiser participating value gain': 'Total Advertiser Reward Gain',
    'advertiser 0 participating value gain': 'Advertiser 0 Reward Gain',
    'advertiser 1 participating value gain': 'Advertiser 1 Reward Gain',
}


def _scatter_axis_label(metric, labels, param_name, remove_gain):
    """Return the display label for `metric`, optionally without the word ' Gain'; KeyError if unknown."""
    if metric not in labels:
        raise KeyError(f'Invalid {param_name}: {metric!r}')
    label = labels[metric]
    return label.replace(' Gain', '') if remove_gain else label


def plot_pearson_correlation(df, save_plot = False, 
        x_metric = 'total advertiser utility gain', y_metric = 'total advertiser participating value gain', use_expansion = True,
        y_min = -1.05, y_max = 1.05, 
        font_size = 25, tick_size = 18):
    """
    Plots, for each 'samples used' value, the Pearson correlation between a utility
    metric (with and without the payment offset) and a second metric.

    Parameters:
    - df: DataFrame containing the experiments data. Must contain the columns
      'use_input_expansion', 'samples used', f'{x_metric} zero bid offset',
      f'{x_metric} no offset' and y_metric.
    - save_plot: True to save a PDF under ./plots/ (without title), False to show the plot.
    - x_metric: base column name; ' zero bid offset' / ' no offset' is appended to select
      the with-offset and without-offset variants.
    - y_metric: column correlated against both x_metric variants.
    - use_expansion: keep only rows whose 'use_input_expansion' equals this value.
    - y_min, y_max: y-axis limits.
    - font_size, tick_size: label and tick font sizes.

    Raises:
    - ValueError: if no rows match use_expansion.
    """
    df_filtered = df[df['use_input_expansion'] == use_expansion]
    if df_filtered.empty:
        raise ValueError(f'No rows with use_input_expansion == {use_expansion!r}')

    grouped_df = df_filtered.groupby('samples used')

    # One coefficient per 'samples used' value, for each payment-offset variant
    correlations_with_offset_series = pd.Series(
        calculate_correlations(grouped_df, x_metric + ' zero bid offset', y_metric),
        name='With Offset')
    correlations_without_offset_series = pd.Series(
        calculate_correlations(grouped_df, x_metric + ' no offset', y_metric),
        name='Without Offset')

    plt.figure(figsize=(12, 8))
    correlations_with_offset_series.plot(label='With Offset', marker='o', color='orange', alpha=1)
    correlations_without_offset_series.plot(label='Without Offset', marker='x', color='blue', alpha=1)
    plt.legend(fontsize=font_size)

    plt.ylim(y_min, y_max)
    plt.xlabel('Candidate Replies Generated', fontsize=font_size)
    plt.ylabel('Pearson Correlation Coefficient', fontsize=font_size)
    plt.grid(True)
    # Integer ticks from 1 to the largest 'samples used' value in the data
    # (previously hard-coded to 1..20).
    plt.xticks(np.arange(1, df_filtered['samples used'].max() + 1, step=1))
    plt.tick_params(axis='both', which='major', labelsize=tick_size)

    if not save_plot:
        plt.title(f'{x_metric} vs {y_metric} Correlation', fontsize=font_size - 10)
        plt.show()
    else:
        file_format = 'pdf'
        savename = f'./plots/pearson_correlation_{x_metric}_{y_metric}_input_expansion_{use_expansion}.{file_format}'
        savename = savename.replace(' ', '_')
        os.makedirs(os.path.dirname(savename), exist_ok=True)
        plt.savefig(savename)
        plt.close()  # release the figure; the __main__ loop saves several plots


def _satisfaction_by_samples(df_filtered, offset_suffix, satisfaction_threshold, number_of_agents):
    """Return, per 'samples used' value, the fraction of advertiser utility gains >= the threshold."""
    columns = [f'advertiser {agent} utility gain {offset_suffix}' for agent in range(number_of_agents)]
    # Satisfied advertisers per row; NaN compares False, so it counts as unsatisfied.
    satisfied_per_row = (df_filtered[columns] >= satisfaction_threshold).sum(axis=1)
    satisfied = satisfied_per_row.groupby(df_filtered['samples used']).sum()
    total_agents = df_filtered.groupby('samples used').size() * number_of_agents
    return satisfied / total_agents


def plot_satisfaction(df, use_expansion, y_metric='utility gain', satisfaction_threshold=0, font_size=25, tick_size=18, save_plot=False, 
                      number_of_agents=2):
    """
    Plots satisfaction, defined as the fraction of (run, advertiser) pairs whose utility gain is at
    or above `satisfaction_threshold`, as a function of the number of samples used. Compares
    satisfaction with and without the payment offset.

    Parameters:
    - df: DataFrame containing the experiments data. Must contain 'use_input_expansion',
      'samples used' and, for every advertiser i < number_of_agents, the columns
      'advertiser {i} utility gain zero bid offset' and 'advertiser {i} utility gain no offset'.
    - use_expansion: Boolean, True for data with input expansion, False for data without.
    - y_metric: only used in the saved file name. Satisfaction is always computed from the
      per-advertiser utility gain columns listed above.
    - satisfaction_threshold: an advertiser counts as satisfied when its utility gain is >= this value.
    - font_size: Font size for plot annotations.
    - tick_size: Tick size for plot.
    - save_plot: True to save a PDF under ./plots/ (without title), False to show the plot.
    - number_of_agents: number of advertisers per run.

    Raises:
    - ValueError: if number_of_agents < 1 or no rows match use_expansion.
    """
    if number_of_agents < 1:
        raise ValueError(f'number_of_agents must be at least 1, got {number_of_agents}')

    df_filtered = df[df['use_input_expansion'] == use_expansion]
    if df_filtered.empty:
        raise ValueError(f'No rows with use_input_expansion == {use_expansion!r}')

    satisfaction_with_offset_series = _satisfaction_by_samples(
        df_filtered, 'zero bid offset', satisfaction_threshold, number_of_agents).rename('With Offset')
    satisfaction_without_offset_series = _satisfaction_by_samples(
        df_filtered, 'no offset', satisfaction_threshold, number_of_agents).rename('Without Offset')

    plt.figure(figsize=(12, 8))
    satisfaction_with_offset_series.plot(label='With Offset', marker='o', color='orange', alpha=0.5)
    satisfaction_without_offset_series.plot(label='Without Offset', marker='x', color='blue', alpha=0.5)
    plt.legend(fontsize=font_size)

    plt.xlabel('Samples Used', fontsize=font_size)
    plt.ylabel('Satisfaction', fontsize=font_size)
    plt.grid(True)
    plt.xticks(np.arange(1, max(df_filtered['samples used']) + 1, step=1))
    plt.tick_params(axis='both', which='major', labelsize=tick_size)

    if save_plot:
        file_format = 'pdf'
        savename = f'./plots/satisfaction_{y_metric}_input_expansion_{use_expansion}.{file_format}'
        savename = savename.replace(' ', '_')
        os.makedirs(os.path.dirname(savename), exist_ok=True)
        plt.savefig(savename)
        plt.close()  # release the figure after saving
    else:
        plt.title(f'Satisfaction vs. Samples Used input expansion: {use_expansion}', fontsize=font_size)
        plt.show()



def plot_utility_gain_no_ci(df, use_expansion, save_plot=False, 
        y_min=None, y_max=None, 
        font_size=25, tick_size=18):
    """
    Plots the mean total advertiser utility gain with and without the payment offset as a
    function of the number of generated candidate sequences, without confidence bands.

    This variant has its own name because it shared the name `plot_utility_gain` with the
    confidence-interval version defined later in this module. That version silently replaced
    it, so it could never be called. It also saves to a separate file so the two plots do
    not overwrite each other.

    Parameters:
    - df: DataFrame containing the experiments data.
    - use_expansion: Boolean, True for data with input expansion, False for data without.
    - save_plot: Boolean, True to save the plot under ./plots/, False to show it.
    - y_min: Float or None, minimum value for the y-axis (None keeps the automatic limit).
    - y_max: Float or None, maximum value for the y-axis (None keeps the automatic limit).
    - font_size: Integer, font size for plot annotations.
    - tick_size: Integer, tick size for plot.

    Raises:
    - ValueError: if no rows match use_expansion.
    """
    df_filtered = df[df['use_input_expansion'] == use_expansion]
    if df_filtered.empty:
        raise ValueError(f'No rows with use_input_expansion == {use_expansion!r}')

    grouped_df = df_filtered.groupby('samples used')

    # Mean utility gain per number of samples, with and without the payment offset
    utility_gain_with_offset = grouped_df['total advertiser utility gain zero bid offset'].mean()
    utility_gain_without_offset = grouped_df['total advertiser utility gain no offset'].mean()

    plt.figure(figsize=(12, 8))
    utility_gain_with_offset.plot(label='With Offset', marker='o', color='orange', alpha=0.5)
    utility_gain_without_offset.plot(label='Without Offset', marker='x', color='blue', alpha=0.5)
    plt.legend(fontsize=font_size)

    # Either bound may be given on its own; a None bound leaves that side unchanged.
    if y_min is not None or y_max is not None:
        plt.ylim(y_min, y_max)

    plt.xlabel('Samples Used', fontsize=font_size)
    plt.ylabel('Total Advertiser Utility Gain', fontsize=font_size)
    plt.grid(True)
    plt.xticks(np.arange(1, max(df_filtered['samples used']) + 1, step=1))
    plt.tick_params(axis='both', which='major', labelsize=tick_size)

    if save_plot:
        file_format = 'pdf'
        savename = f'./plots/utility_gain_mean_input_expansion_{use_expansion}.{file_format}'
        savename = savename.replace(' ', '_')
        os.makedirs(os.path.dirname(savename), exist_ok=True)
        plt.savefig(savename)
        plt.close()  # release the figure after saving
    else:
        plt.title(f'Total Advertiser Utility Gain vs. Samples Used (Input Expansion: {use_expansion})', fontsize=font_size)
        plt.show()


def plot_utility_gain(df, use_expansion, save_plot=False, 
        y_min=None, y_max=None, 
        font_size=25, tick_size=18, ci=0.95):
    """
    Plots the total advertiser utility gain with and without the payment offset as a function
    of the number of generated candidate sequences. Each curve is the per-group mean, shown
    with a shaded Student-t confidence band.

    Parameters:
    - df: DataFrame containing the experiments data.
    - use_expansion: Boolean, True for data with input expansion, False for data without.
    - save_plot: Boolean, True to save the plot under ./plots/, False to show it.
    - y_min: Float or None, minimum value for the y-axis (None keeps the automatic limit).
    - y_max: Float or None, maximum value for the y-axis (None keeps the automatic limit).
    - font_size: Integer, font size for plot annotations.
    - tick_size: Integer, tick size for plot.
    - ci: Float in the open interval (0, 1), confidence level of the band (0.95 = 95%).

    Raises:
    - ValueError: if ci is not in (0, 1) or no rows match use_expansion.
    """
    if not 0 < ci < 1:
        raise ValueError(f'ci must be in the open interval (0, 1), got {ci!r}')

    def calc_confidence_interval(data):
        # Student-t interval around the sample mean. The bounds are NaN for a single
        # observation, because sem and the t quantile with 0 degrees of freedom are undefined.
        n = len(data)
        m = np.mean(data)
        h = sem(data) * t.ppf((1 + ci) / 2, n - 1)
        return m, m - h, m + h

    df_filtered = df[df['use_input_expansion'] == use_expansion]
    if df_filtered.empty:
        raise ValueError(f'No rows with use_input_expansion == {use_expansion!r}')

    grouped_df = df_filtered.groupby('samples used')

    def mean_and_bounds(column):
        # Split the per-group (mean, lower, upper) tuples into three aligned Series.
        stats = grouped_df[column].apply(calc_confidence_interval)
        frame = pd.DataFrame(stats.tolist(), index=stats.index, columns=['mean', 'lower', 'upper'])
        return frame['mean'], frame['lower'], frame['upper']

    mean_with_offset, lower_with_offset, upper_with_offset = mean_and_bounds(
        'total advertiser utility gain zero bid offset')
    mean_without_offset, lower_without_offset, upper_without_offset = mean_and_bounds(
        'total advertiser utility gain no offset')

    plt.figure(figsize=(12, 8))

    plt.plot(mean_with_offset.index, mean_with_offset, label='With Offset', marker='o', color='orange', alpha=0.7)
    plt.fill_between(mean_with_offset.index, lower_with_offset, upper_with_offset, color='orange', alpha=0.2)

    plt.plot(mean_without_offset.index, mean_without_offset, label='Without Offset', marker='x', color='blue', alpha=0.7)
    plt.fill_between(mean_without_offset.index, lower_without_offset, upper_without_offset, color='blue', alpha=0.2)

    plt.legend(fontsize=font_size)

    # A None bound leaves that side of the y-axis unchanged.
    plt.ylim(y_min, y_max)

    plt.xlabel('Candidate Sequences Generated', fontsize=font_size)
    plt.ylabel('Total Advertiser Utility', fontsize=font_size)
    plt.grid(True)
    plt.xticks(np.arange(1, max(df_filtered['samples used']) + 1, step=1))
    plt.tick_params(axis='both', which='major', labelsize=tick_size)

    if save_plot:
        file_format = 'pdf'
        savename = f'./plots/utility_gain_input_expansion_{use_expansion}.{file_format}'
        savename = savename.replace(' ', '_')
        os.makedirs(os.path.dirname(savename), exist_ok=True)
        plt.savefig(savename)
        plt.close()  # release the figure after saving
    else:
        plt.title(f'Total Advertiser Utility Gain vs. Samples Used (Input Expansion: {use_expansion})', fontsize=font_size)
        plt.show()



# Script entry point: load the run data and generate the currently enabled plots.
if __name__ == '__main__':
    df = pd.read_csv('all_runs_data_4.csv')

    # --- Reward vs. utility scatter plot (disabled; uncomment to run) --- #
    # plot_scatter_with_regression_and_metrics(df, step_threshold= 4, use_expansion= False, title_suffix= '',
    #         save_plot= True, filter_outliers= False, filter_threshold= 0.995, 
    #         y_metric= 'total advertiser utility gain no offset')

    # --- Utility gain with and without the offset (disabled) --- #
    # plot_utility_gain(df, use_expansion= False, save_plot= True, y_max = 7, y_min = None)

    # --- Single-advertiser scatter plots (disabled) --- #
    # The full metric list also contained the advertiser-1 pairs:
    #   ('advertiser 1 participating value gain', 'advertiser 1 utility gain zero bid offset'),
    #   ('advertiser 1 participating value gain', 'advertiser 1 utility gain no offset'),
    # for use_expansion in [True, False]:
    #     for filter_outliers in [True]:
    #         for x_metric, y_metric in [('advertiser 0 participating value gain', 'advertiser 0 utility gain zero bid offset'), 
    #                 ('advertiser 0 participating value gain', 'advertiser 0 utility gain no offset'), 
    #                 ('total advertiser participating value gain', 'total advertiser utility gain zero bid offset'), 
    #                 ('total advertiser participating value gain', 'total advertiser utility gain no offset')]:
    #             plot_scatter_with_regression_and_metrics(df, step_threshold= 4, use_expansion= use_expansion, title_suffix= '',
    #                 save_plot= True, filter_outliers= filter_outliers, filter_threshold= 0.995, 
    #                 x_metric= x_metric, y_metric= y_metric)

    # --- Pearson correlations for every metric pair, with and without input expansion --- #
    pearson_metric_pairs = [
        ('advertiser 0 utility gain', 'advertiser 0 participating value gain'),
        ('advertiser 1 utility gain', 'advertiser 1 participating value gain'),
        ('total advertiser utility gain', 'total advertiser participating value gain'),
    ]
    for use_expansion in (True, False):
        for x_metric, y_metric in pearson_metric_pairs:
            plot_pearson_correlation(
                df,
                save_plot=True,
                x_metric=x_metric,
                y_metric=y_metric,
                use_expansion=use_expansion,
                y_min=0,  # use y_min=-1.05 to show the full [-1, 1] range
                y_max=1.05,
                font_size=25,
                tick_size=18,
            )

    # --- Single Pearson plot, shown interactively (disabled) --- #
    # plot_pearson_correlation(df, save_plot= False, 
    #     x_metric= 'advertiser 1 utility gain', y_metric= 'advertiser 1 participating value gain', use_expansion= True,
    #     y_min= -1.05, y_max= 1.05, 
    #     font_size= 25, tick_size= 18)

    # --- Satisfaction (disabled) --- #
    # plot_satisfaction(df, use_expansion= False, y_metric='total advertiser utility gain', satisfaction_threshold=0, font_size=25, tick_size=18, save_plot=False)